import numpy as np
import torch
import time

import models
from utils import *
from pathlib import Path


# Default hyper-parameters for train(): mini-batch size, number of steps T,
# initial learning rate, FGD rounding granularity eps, and momentum.
ARGS = dict(
  bsize=200,
  T=1000,
  lr=0.005,
  eps=0.005,
  mmt=0.9,
)

# Sample cap for compute_mean_grad (presumably the MNIST training-set size,
# i.e. no cap in practice).
N_DEBUG = 60000


def get_dataset(data_path='dataset/mnist'):
  """Return the MNIST ``(train, test)`` datasets stored under ``data_path``.

  Both splits share one transform: ToTensor followed by Normalize with
  mean 0.1307 and std 0.3081. ``download=True`` is passed for both.
  """
  data_transform = transforms.Compose([
    transforms.ToTensor(),
    transforms.Normalize([0.1307], [0.3081]),
  ])
  # Train split first, then test split (same construction order as before).
  splits = [
    dset.MNIST(data_path, train=is_train, transform=data_transform,
               download=True)
    for is_train in (True, False)
  ]
  return splits[0], splits[1]


def compute_mean_grad(model, loss_fn, lr, data_ld, g_out, max_samples=None):
  """Store ``lr`` times the mean per-sample gradient over ``data_ld`` in ``g_out``.

  Gradients are accumulated batch by batch. Iteration stops early once at
  least ``max_samples`` samples have been seen. The accumulated gradient is
  then divided by the number of samples seen, so ``loss_fn`` is expected to
  use ``reduction='sum'`` for the result to be a per-sample mean.

  Args:
    model: module whose parameter gradients are computed. Each batch is moved
      to the device of the model's first parameter.
    loss_fn: callable ``loss_fn(output, target)`` returning a scalar loss.
    lr: scale applied to the mean gradient.
    data_ld: iterable of ``(x, y)`` batches.
    g_out: list of tensors in ``model.parameters()`` order and shape. They are
      overwritten in place.
    max_samples: sample cap. ``None`` uses the module-level ``N_DEBUG``.

  Raises:
    ValueError: if ``data_ld`` yields no samples.
  """
  limit = N_DEBUG if max_samples is None else max_samples
  # Follow the model's device instead of hard-coding CUDA.
  device = next(model.parameters()).device
  model.zero_grad()
  n = 0
  for x, y in data_ld:
    x, y = x.to(device), y.to(device)
    loss_fn(model(x), y).backward()
    n += len(y)
    if n >= limit:
      break
  if n == 0:
    raise ValueError('data_ld yielded no samples')
  scale = lr / float(n)
  with torch.no_grad():
    for i, p in enumerate(model.parameters()):
      if p.grad is None:
        # The parameter took no part in the forward pass, so its gradient is zero.
        g_out[i].zero_()
      else:
        g_out[i].copy_(p.grad)
        g_out[i].mul_(scale)


def update_model(model, g):
  """Take the step ``g`` in place: every parameter p_i becomes p_i - g[i].

  ``g`` is indexed in the same order as ``model.parameters()``.
  """
  with torch.no_grad():
    for idx, param in enumerate(model.parameters()):
      param.data.sub_(g[idx])


def compute_bound(n, m, T, d, eps, sum_lr2grad2, delta=0.1):
  """Evaluate the gradient-based bound (logged by train as 'bound_2').

    bound = (log(1/delta) + 3) / (n - m)
            + (log d + log T) / eps**2 * sum_lr2grad2 / (n - m)

  Args:
    n: size of the training set S.
    m: size of the prior subset SJ; must be smaller than ``n``.
    T: number of steps taken so far.
    d: number of model parameters.
    eps: FGD rounding granularity.
    sum_lr2grad2: accumulated sum of per-step lr^2 * gradient-difference norms.
    delta: confidence parameter. Defaults to the previously hard-coded 0.1.

  Raises:
    ValueError: if ``n <= m``. The bound divides by ``n - m``.
  """
  if n <= m:
    raise ValueError(f'compute_bound needs n > m, got n={n}, m={m}')
  n_out = n - m
  bound = (np.log(1 / delta) + 3) / n_out
  bound += (1 / (eps ** 2.)) * (np.log(d) + np.log(T)) * sum_lr2grad2 / n_out
  return bound


def train(fname, args, data_train, data_prior, data_test, fgd=True):
  """Train ``models.Cnn(1)`` with (floored) gradient descent and log bounds.

  Each step does the following:
    1. Compute ``g1 = lr * grad(S)`` over ``data_train`` and
       ``g2 = lr * grad(SJ)`` over ``data_prior``.
    2. Form ``g1 <- g2 + eps * round((g1 - g2) / eps)``. The rounding via the
       ``vec_round`` helper is skipped when ``fgd`` is False.
    3. Add the heavy-ball momentum term ``-mmt * (W_{t-1} - W_{t-2})``.
    4. Update ``W_t = W_{t-1} - g1``.

  Every 10 steps, train/test accuracy and two bounds are appended to
  ``fname`` via ``print_log``. The learning rate is multiplied by 0.9 every
  150 steps.

  Args:
    fname: log file path. It is truncated first, and its parent directory is
      created if needed.
    args: dict with keys 'bsize', 'T', 'lr', 'eps', 'mmt'.
    data_train: training set S.
    data_prior: prior set SJ (a subset of ``data_train`` in this script).
    data_test: held-out test set.
    fgd: round the gradient difference (FGD) when True; plain GD when False.

  Returns:
    Sum over steps of ``vec_norm2(lr_t * (grad_S - grad_SJ)) / lr_t**2``.
    Each step's own learning rate is divided out. ``vec_norm2`` is
    presumably the squared L2 norm (TODO confirm in utils).
  """
  bsize = args['bsize']
  # Truncate/create the log without leaking a file handle, and make sure
  # its directory exists.
  log_path = Path(fname)
  log_path.parent.mkdir(parents=True, exist_ok=True)
  log_path.write_text('')

  n, m = len(data_train), len(data_prior)
  print(n, m)
  S_ld = get_data_ld(data_train, bsize)
  SJ_ld = get_data_ld(data_prior, bsize)
  test_ld = get_data_ld(data_test, bsize)

  model = models.Cnn(1)
  model.cuda()
  print('device_count:', torch.cuda.device_count())

  loss_fn = torch.nn.CrossEntropyLoss(reduction='sum')
  loss_fn.cuda()

  # Pass one batch of S to the utils helper zero_grad (per its original
  # comment, this sets model.grad to 0) before the buffers below are built.
  x0, y0 = next(iter(S_ld))
  zero_grad(model, loss_fn, x0, y0)

  # Buffers with the same shapes as model.parameters().
  p_last = clone_param(model)  # previous weights, for momentum
  g_mmt = clone_param(model)   # momentum term
  g1 = clone_param(model)      # lr * grad on S, then the final step
  g2 = clone_param(model)      # lr * grad on SJ

  d = number_of_parameter(model)
  print_log({'d': d}, fname)

  model.train()
  lr = args['lr']
  eps = args['eps']
  mmt = args['mmt']

  # Running sum of lr_t^2 * ||grad(S) - grad(SJ)||^2, fed to compute_bound.
  sum_lr2grad2 = lr2grad2 = 0.
  # Same sum with each step's own lr_t^2 divided out. This is the return
  # value. Dividing the total by the final lr would be wrong because lr
  # decays during training.
  sum_grad2 = 0.

  for t in range(args['T']):
    print(time.strftime('%Y-%m-%d %H:%M:%S', time.localtime(time.time())), t)
    # Logging information.
    if t % 10 == 0:
      tr_acc = accuracy(S_ld, model)
      te_acc = accuracy(test_ld, model)
      SJ_acc = accuracy(SJ_ld, model)
      # Presumably the error rate on S \ SJ. This assumes SJ is a subset of S
      # and that accuracy() returns a fraction.
      bound_1 = 1. - (tr_acc * n - SJ_acc * m) / (n - m)
      bound_2 = compute_bound(n, m, t + 1, d, eps, sum_lr2grad2)
      log_info_dic = {'step' : t,
                      'tr_acc': tr_acc,
                      'te_acc': te_acc,
                      'lr2grad2': lr2grad2,
                      'bound_1' : bound_1,
                      'bound_2': bound_2,}
      print_log(log_info_dic, fname)

    # g1 <-- lr * grad(w, S)
    compute_mean_grad(model, loss_fn, lr, S_ld, g1)
    # g2 <-- lr * grad(w, SJ)
    compute_mean_grad(model, loss_fn, lr, SJ_ld, g2)
    # g1 <-- g2 + eps * round((g1 - g2)/eps); the rounding is skipped for GD.
    vec_sub(g1, g2)
    lr2grad2 = vec_norm2(g1)
    grad2 = lr2grad2 / (lr ** 2.)  # uses this step's lr (decay happens below)
    sum_lr2grad2 += lr2grad2
    sum_grad2 += grad2
    vec_mul(g1, 1. / eps)
    if fgd:
      vec_round(g1)
    vec_mul(g1, eps)
    vec_add(g1, g2)

    if mmt > 1e-4:
      # g_mmt <-- (W_{t-1} - W_{t-2}) * (-mmt)
      copy_param(g_mmt, model)
      vec_sub(g_mmt, p_last)
      vec_mul(g_mmt, -mmt)
      # g1 <-- g_mmt + g1
      vec_add(g1, g_mmt)
      copy_param(p_last, model)

    # W{t} <-- W{t-1} - g1
    update_model(model, g1)
    print(grad2)

    # Learning-rate decay: x0.9 every 150 steps.
    if t % 150 == 0 and t > 0:
      lr = lr * 0.9

  return sum_grad2


def run_training_process(rid):
  """Run plain GD and then FGD for repetition ``rid`` on a shared prior set.

  A random half J of the MNIST training set is drawn without replacement
  and used as the prior set SJ for both runs. Each run is skipped when its
  log file ('log/mnist/gd/{rid}.out' or 'log/mnist/fgd/{rid}.out') exists.
  """
  data_train, data_test = get_dataset()
  n = len(data_train)
  J = np.random.choice(n, n // 2, replace=False)
  data_prior = torch.utils.data.Subset(data_train, J)
  runs = ((f'log/mnist/gd/{rid}.out', False),
          (f'log/mnist/fgd/{rid}.out', True))
  for log_file, use_fgd in runs:
    if not Path(log_file).exists():
      train(log_file, ARGS, data_train, data_prior, data_test, fgd=use_fgd)

def run_training_random_label(rid, p):
  """Run FGD for repetition ``rid`` after applying ``random_label(..., p)``.

  ``random_label`` (a utils helper, presumably corrupting a fraction ``p`` of
  labels) is applied to both the train and the test set. A random half of
  the training set is the prior set. The run is skipped when
  'log/mnist/random_label/{p}/{rid}.out' already exists.
  """
  fname = f'log/mnist/random_label/{p}/{rid}.out'
  data_train, data_test = get_dataset()
  n = len(data_train)
  # NOTE(review): the test labels are randomized too; confirm this is intended.
  for dataset in (data_train, data_test):
    random_label(dataset, p)

  J = np.random.choice(n, n // 2, replace=False)
  data_prior = torch.utils.data.Subset(data_train, J)
  if Path(fname).exists():
    return
  train(fname, ARGS, data_train, data_prior, data_test)

def study_m_graddiff(rid):
  """Record how train()'s summed gradient difference varies with prior size m.

  For m = 1000, 2000, ..., 9000 the function does the following:
    - draw a random subset J of size m of the MNIST training set;
    - train with FGD, logging to 'log/mnist/debug.out';
    - append ``{'m', 'sum_grad2'}`` to 'log/mnist/mgrad/{rid}.out'.

  Does nothing if that result file already exists.
  """
  fname = f'log/mnist/mgrad/{rid}.out'
  out_path = Path(fname)
  if out_path.exists():
    return
  data_train, data_test = get_dataset()
  n = len(data_train)
  # Truncate/create the result file (and its directory) without leaking a
  # file handle.
  out_path.parent.mkdir(parents=True, exist_ok=True)
  out_path.write_text('')
  for m in range(1000, 10000, 1000):
    J = np.random.choice(n, m, replace=False)
    data_prior = torch.utils.data.Subset(data_train, J)
    sum_grad2 = train('log/mnist/debug.out', ARGS,
                      data_train, data_prior, data_test)
    print_log({'m': m, 'sum_grad2': sum_grad2}, fname)


if __name__ == '__main__':
  for rid in range(100):
    # FGD vs. plain GD on a shared prior subset.
    run_training_process(rid)

    # Random-label experiments at increasing corruption rates.
    for label_rate in (0.1, 0.2, 1.0):
      run_training_random_label(rid, label_rate)

    # The gradient difference is expected to shrink as m grows.
    study_m_graddiff(rid)